{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run upon export from spreadsheet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from astroquery.mast import Catalogs\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "\n",
    "tces_file = '/mnt/tess/labels/tces-triage-v12.csv'\n",
    "labels_file = '/mnt/tess/labels/labels-triage-v12.csv'\n",
    "\n",
    "y2_tics = pd.read_csv('/mnt/tess/labels/labels_v6.csv', header=0, low_memory=False).set_index('TIC ID')\n",
    "\n",
    "tce_table = pd.read_csv(tces_file, header=0, low_memory=False).set_index('tic_id')\n",
    "tce_table = tce_table[~tce_table.instar]\n",
    "\n",
    "tce_table = tce_table[tce_table.index.isin(y2_tics.index)]\n",
    "\n",
    "\n",
    "joined_table = tce_table\n",
    "joined_table = joined_table.reset_index()[[\n",
    "  'tic_id', 'Tmag', 'Epoc', 'Period', 'Duration',\n",
    "  'Transit_Depth', 'star_rad', 'star_mass',\n",
    "  'filename', 'Split'\n",
    "]]\n",
    "joined_table = joined_table.set_index('tic_id')\n",
    "\n",
    "\n",
    "labels_table = pd.read_csv(labels_file, header=0, low_memory=False)\n",
    "labels_table['tic_id'] = labels_table['TIC ID']\n",
    "labels_table = labels_table[labels_table.index <= 26382 - 2]\n",
    "\n",
    "disps = ['E', 'J', 'N', 'S', 'B']\n",
    "users = ['av', 'md', 'ch', 'as', 'mk', 'et', 'dm', 'td']\n",
    "for d in disps:\n",
    "  labels_table[f'disp_{d}'] = 0\n",
    "\n",
    "def set_labels(row):\n",
    "  a = ~row.isna()\n",
    "  if a['Final'] and row[\"Final\"] != 'U':\n",
    "    has_label = True\n",
    "    row[f'disp_{row[\"Final\"]}'] = 1\n",
    "  else:\n",
    "    has_label = False\n",
    "    for user in users:\n",
    "      if a[user] and row[user] and row[user] != 'U':\n",
    "        has_label = True\n",
    "        row[f'disp_{row[user]}'] += 1\n",
    "  if not has_label:\n",
    "    row['Exclude'] = 'Y'\n",
    "  return row\n",
    "labels_table = labels_table.apply(set_labels, axis=1)\n",
    "\n",
    "labels_table = labels_table[labels_table.Exclude != 'Y']\n",
    "labels_table = labels_table[['tic_id'] + [f'disp_{d}' for d in disps]]\n",
    "labels_table = labels_table.set_index('tic_id')\n",
    "\n",
    "joined_table = joined_table.join(labels_table, on='tic_id', how='inner')\n",
    "print(f'Total entries: {len(joined_table)}')\n",
    "joined_table = joined_table[\n",
    "    sum(joined_table[f'disp_{d}'] for d in disps) > 0\n",
    "]\n",
    "print(f'Total labeled entries: {len(joined_table)}')\n",
    "\n",
    "\n",
    "t_train = joined_table[joined_table['Split'] == 'train'].drop(columns=['Split'])\n",
    "t_val = joined_table[joined_table['Split'] == 'val'].drop(columns=['Split'])\n",
    "t_test = joined_table[joined_table['Split'] == 'test'].drop(columns=['Split'])\n",
    "all_table = joined_table.drop(columns=['Split'])\n",
    "\n",
    "\n",
    "print(f'Split sizes. Train: {len(t_train)}; Valid: {len(t_val)}; Test: {len(t_test)}')\n",
    "print(f'Duplicate TICs: {len(all_table.index.values) - len(set(all_table.index.values))}')\n",
    "\n",
    "t_train.to_csv('/mnt/tess/astronet/tces-v12-y2-train.csv')\n",
    "t_val.to_csv('/mnt/tess/astronet/tces-v12-y2-val.csv')\n",
    "t_test.to_csv('/mnt/tess/astronet/tces-v12-y2-test.csv')\n",
    "all_table.to_csv('/mnt/tess/astronet/tces-v12-y2-all.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.set_option('display.max_columns', None)\n",
    "t_train.sample(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t_val.sample(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t_test.sample(5)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
